import argparse
import numpy
import torch
import time
import utils

# Command-line interface
#
# --env and --model are mandatory; every other option has a default.

parser = argparse.ArgumentParser()

parser.add_argument(
    "--env",
    required=True,
    help="name of the environment to be run (REQUIRED)",
)
parser.add_argument(
    "--model",
    required=True,
    help="name of the trained model (REQUIRED)",
)
parser.add_argument(
    "--seed",
    type=int,
    default=0,
    help="random seed (default: 0)",
)
parser.add_argument(
    "--pause",
    type=float,
    default=0.1,
    help="pause duration between two consequent actions of the agent (default: 0.1)",
)
parser.add_argument(
    "--episodes",
    type=int,
    default=10,
    help="number of episodes to visualize",
)

# Plain on/off switches that carry no help text; each one is False unless given.
for _switch in ("--no-render", "--use-mem", "--use-mem-detector"):
    parser.add_argument(_switch, action="store_true", default=False)

parser.add_argument(
    "--no-rm",
    action="store_true",
    default=False,
    help="The agent is ignorant of any RM states.",
)
parser.add_argument(
    "--rm-update-algo",
    type=str,
    default="rm_detector",
    help="[rm_detector, rm_threshold, event_threshold, independent_belief, perfect_rm]",
)

args = parser.parse_args()

# Seed every source of randomness handled by utils.seed with the CLI seed
utils.seed(args.seed)

# Choose the compute device: CUDA when torch reports it available, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device: {device}\n")

# Build the environment and reset it once before the agent is created
env = utils.make_env(args.env, args.rm_update_algo, args.seed)
env.reset()
print("Environment loaded\n")

# Build the agent.
# The --model value is passed through unchanged as the model directory;
# it is not resolved through utils.get_model_dir.
model_dir = args.model

agent = utils.Agent(
    env,
    env.observation_space,
    env.action_space,
    model_dir,
    args.rm_update_algo,
    use_mem=args.use_mem,
    use_mem_detector=args.use_mem_detector,
    no_rm=args.no_rm,
    device=device,
)
print("Agent loaded\n")

# Create a window to view the environment (skipped with --no-render)
if not args.no_render:
    print(env.unwrapped)
    env.unwrapped.render()

# Undiscounted return of every finished episode, in the order they were run
episode_returns = []

for episode in range(args.episodes):
    episode_return = 0
    num_steps = 0
    obs = env.reset()

    # Step the environment until it reports the episode as done
    done = False
    while not done:
        if not args.no_render:
            env.unwrapped.render()
            # Honor the --pause option; the delay used to be hard-coded to
            # 0.2 s, which silently ignored the value given on the command line.
            time.sleep(args.pause)

        action = agent.get_action(obs, display_rm_belief_as_mission=True)
        obs, reward, done, _ = env.step(action)
        # The agent is told after every step whether the episode has ended
        agent.analyze_feedback(done)
        episode_return += reward
        num_steps += 1

    print("Episode %i --- Return: %.3f --- Num Steps: %d" % (episode + 1, episode_return, num_steps))
    episode_returns.append(episode_return)

# numpy.mean / numpy.std of an empty list warn and yield nan, so only report
# statistics when at least one episode was actually run (e.g. not --episodes 0).
if episode_returns:
    print("Average return:", numpy.mean(episode_returns), "Std:", numpy.std(episode_returns))
else:
    print("No episodes were run; no return statistics to report.")
